import argparse

from lib.trainer import TrainConfig, Trainer
from lib.utils import parse_config_file, get_input_dataset
import random
import numpy as np
import torch


if __name__ == '__main__':
    # Script entry point: runs only when this file is executed directly,
    # not when it is imported as a module.

    # Name of the dataset to train on, as returned by get_input_dataset().
    selected_dataset = get_input_dataset()

    # The parsed settings are unpacked as keyword arguments into
    # TrainConfig, so parse_config_file() must return a mapping whose keys
    # TrainConfig accepts.
    train_settings = parse_config_file(selected_dataset)
    train_config = TrainConfig(**train_settings)
    runner = Trainer(train_config)

    # Bracket the training run with simple progress messages on stdout.
    print("Start training...")
    runner.train()
    print("Training finished!")
